import torch.nn as nn
import torch
import torch.nn.functional as F
import torch.nn.init as init

def conv3x3(in_planes, out_planes, stride=1):
    """Build a 3x3 ``nn.Conv2d`` with padding 1 and a bias term.

    With ``stride=1`` the spatial size is preserved; larger strides
    downsample accordingly.
    """
    layer = nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=True,
    )
    return layer

def conv_init(m):
    """Initialise one module in place, e.g. via ``model.apply(conv_init)``.

    - Modules whose class name contains ``'Conv'`` get Xavier-uniform weights
      (gain ``sqrt(2)``) and a zero bias.
    - Modules whose class name contains ``'BatchNorm'`` get weight 1 and bias 0.

    Missing tensors are skipped rather than crashing. This covers convs built
    with ``bias=False``, ``affine=False`` norms, and containers whose class
    name merely contains "Conv".
    """
    import math  # local import: the file does not import numpy at top level

    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            init.xavier_uniform_(weight, gain=math.sqrt(2))
        if bias is not None:
            init.constant_(bias, 0)
    elif 'BatchNorm' in classname:
        if weight is not None:
            init.constant_(weight, 1)
        if bias is not None:
            init.constant_(bias, 0)

class SimpleCNN_b(nn.Module):
    """Three-stage conv net with BatchNorm, ReLU and 2x2 average pooling.

    Each stage is a 3x3 conv (128 output channels, padding 1), then
    ``BatchNorm2d``, then ReLU, then ``AvgPool2d(2, 2)``. The flattened
    result feeds a single linear classifier.

    Args:
        shape: Per-sample input shape. ``shape[0]`` is the channel count and
            ``shape[1]`` the height. ``shape[2]``, if present, is the width;
            otherwise the input is assumed square.
        num_classes: Number of output logits.
    """

    def __init__(self, shape, num_classes):
        super(SimpleCNN_b, self).__init__()

        self.num_classes = num_classes

        # Activation function
        self.act = nn.ReLU(inplace=True)

        # Normalization layers. NOTE: ``self.norm`` is never used in forward();
        # it is kept so existing state_dict keys and parameter order remain
        # compatible with saved checkpoints.
        self.norm = nn.BatchNorm2d(128, affine=True)
        self.norm1 = nn.BatchNorm2d(128, affine=True)
        self.norm2 = nn.BatchNorm2d(128, affine=True)
        self.norm3 = nn.BatchNorm2d(128, affine=True)

        # Pooling layer: halves each spatial dimension (flooring odd sizes)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Convolutional layers; padding=1 keeps the spatial size unchanged
        self.conv1 = nn.Conv2d(in_channels=shape[0], out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)

        # Three floor-halving pools reduce H and W to H // 8 and W // 8.
        # Width falls back to height when no width is given (square input).
        height = shape[1]
        width = shape[2] if len(shape) > 2 else shape[1]
        self.fc = nn.Linear(128 * (height // 8) * (width // 8), num_classes)

    def forward(self, x):
        """Return class logits of shape ``(x.size(0), num_classes)``."""
        x = self.act(self.norm1(self.conv1(x)))
        x = self.pool(x)

        x = self.act(self.norm2(self.conv2(x)))
        x = self.pool(x)

        x = self.act(self.norm3(self.conv3(x)))
        x = self.pool(x)

        # flatten everything except the leading (batch) dimension
        x = x.view(x.size(0), -1)

        return self.fc(x)

class SimpleCNN(nn.Module):
    """Three-stage conv net with InstanceNorm, ReLU and 2x2 average pooling.

    Each stage is a 3x3 conv (128 output channels, padding 1), then
    ``InstanceNorm2d``, then ReLU, then ``AvgPool2d(2, 2)``. The flattened
    features feed a single linear classifier.

    Args:
        shape: Per-sample input shape. ``shape[0]`` is the channel count and
            ``shape[1]`` the height. ``shape[2]``, if present, is the width;
            otherwise the input is assumed square.
        num_classes: Number of output logits.
    """

    def __init__(self, shape, num_classes):
        super(SimpleCNN, self).__init__()

        self.num_classes = num_classes

        # Activation function
        self.act = nn.ReLU(inplace=True)

        # Normalization layer *class*; one instance is created per stage below
        self.norm = nn.InstanceNorm2d

        # Pooling layer: halves each spatial dimension (flooring odd sizes)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # Convolutional stages; padding=1 keeps the spatial size unchanged
        self.conv1 = nn.Conv2d(in_channels=shape[0], out_channels=128, kernel_size=3, stride=1, padding=1)
        self.norm1 = self.norm(128)

        self.conv2 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.norm2 = self.norm(128)

        self.conv3 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.norm3 = self.norm(128)

        # Three floor-halving pools reduce H and W to H // 8 and W // 8.
        # Width falls back to height when no width is given (square input).
        height = shape[1]
        width = shape[2] if len(shape) > 2 else shape[1]
        self.fc = nn.Linear(128 * (height // 8) * (width // 8), num_classes)

    def forward(self, x):
        """Return class logits of shape ``(x.size(0), num_classes)``."""
        return self.fc(self.feature_extractor(x))

    def feature_extractor(self, x):
        """Run the three conv stages and return features flattened per sample."""
        x = self.act(self.norm1(self.conv1(x)))
        x = self.pool(x)

        x = self.act(self.norm2(self.conv2(x)))
        x = self.pool(x)

        x = self.act(self.norm3(self.conv3(x)))
        x = self.pool(x)

        # flatten everything except the leading (batch) dimension
        return x.view(x.size(0), -1)


class wide_basic(nn.Module):
    """Pre-activation wide residual block.

    Residual path: BN -> ReLU -> conv3x3 -> dropout -> BN -> ReLU ->
    conv3x3 (carrying ``stride``). The shortcut is the identity, or a 1x1
    strided conv when the stride or channel count changes.

    Args:
        in_planes: Input channel count.
        planes: Output channel count.
        dropout_rate: Dropout probability between the two convs.
        stride: Stride of the second conv and of the projection shortcut.
    """

    def __init__(self, in_planes, planes, dropout_rate, stride=1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
        self.dropout = nn.Dropout(p=dropout_rate)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=True)

        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            # 1x1 conv matches both the spatial size and the channel count
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
            )
        else:
            # empty Sequential acts as the identity
            self.shortcut = nn.Sequential()

    def forward(self, x, get_bn = False):
        """Return the block output, or ``(output, mid)`` when ``get_bn`` is true.

        ``mid`` is the dropout-applied output of the first conv, i.e. the
        tensor fed into ``bn2``. Despite the flag's name, it is not a
        BatchNorm output.
        """
        mid = self.dropout(self.conv1(F.relu(self.bn1(x))))
        residual = self.conv2(F.relu(self.bn2(mid)))
        residual += self.shortcut(x)
        return (residual, mid) if get_bn else residual

        
class Wide_ResNet(nn.Module):
    """Wide residual network (WRN-``depth``-``widen_factor``).

    Structure:
    - a 3x3 stem conv from 3 to 16 channels;
    - three stages of ``wide_basic`` blocks with ``16*k``, ``32*k`` and
      ``64*k`` channels, where the last two stages downsample with stride 2;
    - a final BatchNorm + ReLU, an 8x8 average pool and a linear classifier.

    Args:
        depth: Total depth. Must satisfy ``depth == 6 * n + 4``, where ``n``
            is the number of blocks per stage.
        widen_factor: Channel multiplier ``k``.
        dropout_rate: Dropout probability inside every block.
        num_classes: Number of output logits.
    """

    def __init__(self, depth, widen_factor, dropout_rate, num_classes):
        super(Wide_ResNet, self).__init__()
        print('making wideresnet')
        self.in_planes = 16  # running input width, advanced by _wide_layer

        assert ((depth-4)%6 ==0), 'Wide-resnet depth should be 6n+4'
        n = (depth - 4) // 6  # blocks per stage
        k = widen_factor

        print('| Wide-Resnet %dx%d' %(depth, k))
        nStages = [16, 16*k, 32*k, 64*k]

        self.conv1 = conv3x3(3,nStages[0])
        self.layer1 = self._wide_layer(wide_basic, nStages[1], n, dropout_rate, stride=1)
        self.layer2 = self._wide_layer(wide_basic, nStages[2], n, dropout_rate, stride=2)
        self.layer3 = self._wide_layer(wide_basic, nStages[3], n, dropout_rate, stride=2)
        # NOTE(review): in PyTorch, ``momentum`` is the weight given to the *new*
        # batch statistic (default 0.1), so 0.9 makes the running stats follow
        # the latest batch closely. Kept as-is to preserve behaviour; confirm
        # that this is intended.
        self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
        self.linear = nn.Linear(nStages[3], num_classes)

    def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: advances ``self.in_planes`` to ``planes``.
        """
        strides = [stride] + [1]*(int(num_blocks)-1)
        layers = []

        for stride in strides:
            layers.append(block(self.in_planes, planes, dropout_rate, stride))
            self.in_planes = planes

        return nn.Sequential(*layers)

    def calculate_diff(self, x, mean, variance):
        """Return ``||mean(x) - mean||_2 + ||var(x) - variance||_2``.

        ``mean(x)`` and ``var(x)`` are scalars over *all* elements of ``x``;
        ``torch.dist`` broadcasts them against ``mean`` / ``variance``.

        NOTE(review): with per-channel BatchNorm running stats this compares one
        global scalar against every channel. A per-channel statistic
        (``x.mean(dim=(0, 2, 3))``) may have been intended; confirm.
        """
        mean_x = torch.mean(x)
        var_x = x.var()
        loss_mean = torch.dist(mean_x, mean, p=2)
        loss_var = torch.dist(var_x, variance, p=2)

        return loss_mean + loss_var

    def forward(self, x, get_bn = False):
        """Return logits, or ``(logits, diff)`` when ``get_bn`` is true.

        ``diff`` is ``calculate_diff`` between the stem-conv output and the
        running mean/variance of ``layer1[0].bn1``. Those stats are read
        before this pass runs ``layer1``.
        """
        out = self.conv1(x)
        diff = None
        if get_bn:
            first_bn = self.layer1[0].bn1
            # Snapshot the buffers. In training mode, layer1 updates them in
            # place below. That bumps their autograd version and can make
            # backward through torch.dist fail with an in-place error.
            diff = self.calculate_diff(
                out, first_bn.running_mean.clone(), first_bn.running_var.clone()
            )
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.relu(self.bn1(out))
        # Fixed 8x8 pool; presumably sized for 32x32 inputs
        # (32 -> 16 -> 8 after the two stride-2 stages).
        out = F.avg_pool2d(out, 8)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        if get_bn:
            return out, diff
        return out

class TransformerEncoder(nn.Module):
    """Pre-norm transformer encoder block.

    Computes ``h = x + MSA(LN1(x))`` and then returns ``h + MLP(LN2(h))``.

    The MLP is Linear -> GELU -> Dropout -> Linear -> GELU -> Dropout. Note
    that GELU is also applied after the output projection.
    """

    def __init__(self, feats:int, mlp_hidden:int, head:int=8, dropout:float=0.):
        super().__init__()
        self.la1 = nn.LayerNorm(feats)
        self.msa = MultiHeadSelfAttention(feats, head=head, dropout=dropout)
        self.la2 = nn.LayerNorm(feats)
        mlp_layers = [
            nn.Linear(feats, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, feats),
            nn.GELU(),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual add."""
        hidden = x + self.msa(self.la1(x))
        return hidden + self.mlp(self.la2(hidden))


class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention for inputs of shape ``(batch, tokens, feats)``.

    Q/K/V are separate ``feats -> feats`` linear maps, split into ``head``
    heads of width ``feats // head``. Per-head outputs are concatenated, then
    passed through the output projection ``o`` and dropout.
    """

    def __init__(self, feats:int, head:int=8, dropout:float=0.):
        super().__init__()
        self.head = head
        self.feats = feats
        # NOTE(review): scores are divided by sqrt(feats), the full embedding
        # width, not by sqrt(feats // head) as in standard scaled dot-product
        # attention. Kept to preserve the behaviour of existing weights;
        # confirm that this is intended.
        self.sqrt_d = self.feats**0.5

        self.q = nn.Linear(feats, feats)
        self.k = nn.Linear(feats, feats)
        self.v = nn.Linear(feats, feats)

        self.o = nn.Linear(feats, feats)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return attended features with the same shape as ``x``."""
        batch, tokens, _ = x.size()
        head_dim = self.feats // self.head

        def split_heads(t):
            # (b, n, f) -> (b, h, n, f // h)
            return t.view(batch, tokens, self.head, head_dim).transpose(1, 2)

        query = split_heads(self.q(x))
        key = split_heads(self.k(x))
        value = split_heads(self.v(x))

        logits = torch.einsum("bhif, bhjf->bhij", query, key) / self.sqrt_d
        weights = F.softmax(logits, dim=-1)  # (b, h, n, n)
        mixed = torch.einsum("bhij, bhjf->bihf", weights, value)  # (b, n, h, f//h)
        return self.dropout(self.o(mixed.flatten(2)))

class MultiHeadDepthwiseSelfAttention(nn.Module):
    """Placeholder for a depthwise multi-head self-attention layer.

    Not implemented yet: the constructor builds no sub-modules. ``forward``
    raises ``NotImplementedError`` rather than silently returning ``None``,
    which would otherwise fail confusingly further downstream.
    """

    def __init__(self, feats:int, head:int=8, dropout:float=0):
        super(MultiHeadDepthwiseSelfAttention, self).__init__()

    def forward(self, x):
        raise NotImplementedError(
            "MultiHeadDepthwiseSelfAttention.forward is not implemented"
        )

class ViT(nn.Module):
    """Vision Transformer for small images.

    The image is cut into ``patch x patch`` non-overlapping patches. Each
    patch is flattened and linearly embedded to ``hidden`` features. An
    optional CLS token is prepended, and a learned positional embedding is
    added. The sequence then runs through ``num_layers`` encoder blocks.

    Classification uses the CLS token, or the mean over tokens when
    ``is_cls_token`` is false, followed by LayerNorm + Linear.

    Args:
        in_c: Number of input image channels.
        num_classes: Number of output logits.
        img_size: Input height/width in pixels. Assumed square and divisible
            by ``patch``.
        patch: Number of patches along one side (not the patch size).
        dropout: Dropout used inside the encoder blocks.
        num_layers: Number of encoder blocks.
        hidden: Token embedding width.
        mlp_hidden: Hidden width of each encoder MLP.
        head: Number of attention heads.
        is_cls_token: Whether to prepend and classify from a CLS token.
    """

    def __init__(self, in_c:int=3, num_classes:int=10, img_size:int=32, patch:int=8, dropout:float=0., num_layers:int=7, hidden:int=384, mlp_hidden:int=384, head:int=12, is_cls_token:bool=True):
        super(ViT, self).__init__()

        self.patch = patch # number of patches in one row (or col)
        self.is_cls_token = is_cls_token
        self.patch_size = img_size//self.patch  # side length of one patch in pixels
        # Length of one flattened patch vector: patch_size**2 pixels times in_c
        # channels (48 for the defaults: 4 * 4 * 3). _to_words flattens exactly
        # this many values per patch, so the channel count must be in_c.
        f = self.patch_size**2 * in_c
        num_tokens = (self.patch**2)+1 if self.is_cls_token else (self.patch**2)

        self.emb = nn.Linear(f, hidden) # (b, n, f) -> (b, n, hidden)
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden)) if is_cls_token else None
        self.pos_emb = nn.Parameter(torch.randn(1,num_tokens, hidden))
        enc_list = [TransformerEncoder(hidden,mlp_hidden=mlp_hidden, dropout=dropout, head=head) for _ in range(num_layers)]
        self.enc = nn.Sequential(*enc_list)
        self.fc = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, num_classes) # applied to the pooled token
        )

    def forward(self, x):
        """Return class logits for an image batch ``x`` of shape ``(b, in_c, H, W)``."""
        out = self._to_words(x)
        out = self.emb(out)
        if self.is_cls_token:
            out = torch.cat([self.cls_token.repeat(out.size(0),1,1), out],dim=1)
        out = out + self.pos_emb
        out = self.enc(out)
        if self.is_cls_token:
            out = out[:,0]
        else:
            out = out.mean(1)
        out = self.fc(out)
        return out

    def _to_words(self, x):
        """
        (b, c, h, w) -> (b, n, f)

        Each patch is flattened in (row, col, channel) order, so
        ``f = patch_size**2 * c`` and ``n = patch**2``.
        """
        out = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size).permute(0,2,3,4,5,1)
        out = out.reshape(x.size(0), self.patch**2 ,-1)
        return out

